from grid_env_ideal_obs_repeat_task import *
from grid_agent import *
from checkpoint_utils import *
from maze_factory import *
from replay_config import *
import argparse
import json
import sys
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import matplotlib
from matplotlib.patches import Circle
from matplotlib.lines import Line2D
import random
from sklearn.decomposition import PCA
from matplotlib.animation import FuncAnimation
from sklearn.cluster import KMeans
import threading
import mplcursors
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

def progress_bar(current, total, barLength = 100):
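    """Print an in-place text progress bar showing `current` out of `total` steps."""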
    percent = float(current) * 100 / total
    arrow = '-' * int(percent/100 * barLength - 1) + '>'
    spaces = ' ' * (barLength - len(arrow))

    print('Progress: [%s%s] %d %%' % (arrow, spaces, percent), end='\r')
    sys.stdout.flush()

@partial(jax.jit, static_argnums=(3,))
def model_forward(variables, state, x, model):
    """ forward pass of the model
    """
    return model.apply(variables, state, x)

@jit
def get_action(y):
    return jnp.argmax(y)
get_action_vmap = jax.vmap(get_action)

# load landscape and states from file
def load_task(pth = "./logs/task.json"):
    # open json file
    with open(pth, "r") as f:
        data = json.load(f)
        landscape = data["data"]
        state = data["state"]
        goal = data["goal"]
        print("state: ", state)
        print("goal: ", goal)
        print("landscape: ", landscape)
    return landscape, state, goal

# save current landscape as json file
def save_current_task(landscape, start_x, start_y, goal_x, goal_y, pth = "./logs/landscape.json"):
    landscape_ = []
    for j in range(landscape[0].shape[0]):
        landscape_.append(int(landscape[0][j]))

    with open(pth, "w") as f:
        json.dump({"data": landscape_, 
                   "state": [start_x, start_y],
                   "goal": [goal_x, goal_y]}, f)

def render(grid, state, goal, valid = True):
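    """Render the grid as a BGR image: free cells (1) in white, walls (0) in
    black, the agent as a red dot and the goal/food as a green dot; if `valid`
    is False, overlay a red cross and an "invalid map" label."""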
        
    state_x = int(state[0])
    state_y = int(state[1])

    food_x = int(goal[0])
    food_y = int(goal[1])

    grid_size_display = 20
    width, height = grid.shape[0], grid.shape[1]
    img = np.zeros((width * grid_size_display, height * grid_size_display, 3), np.uint8)

    for j in range(width):
        for i in range(height):
            top_left = (i * grid_size_display, j * grid_size_display)
            bottom_right = (i * grid_size_display + grid_size_display, j * grid_size_display + grid_size_display)
            # free cells (1) are filled white, walls (0) black, each with a grey border
            fill_color = (255, 255, 255) if grid[j, i] == 1 else (0, 0, 0)
            cv2.rectangle(img, top_left, bottom_right, fill_color, -1)
            cv2.rectangle(img, top_left, bottom_right, (100, 100, 100), 1)
            if j == state_x and i == state_y:
                # agent position as a red dot
                cv2.circle(img, (i * grid_size_display + grid_size_display // 2, j * grid_size_display + grid_size_display // 2), 7, (0, 0, 255), -1, cv2.LINE_AA)

    # draw a dot at the food (goal) position
    cv2.circle(img, (food_y * grid_size_display + grid_size_display // 2, food_x * grid_size_display + grid_size_display // 2), 7, (0, 100, 0), -1, cv2.LINE_AA)

    if not valid:
        # draw a big red cross and label the map as invalid
        cv2.line(img, (0, 0), (img.shape[1], img.shape[0]), (0, 0, 255), 5, cv2.LINE_AA)
        cv2.line(img, (0, img.shape[0]), (img.shape[1], 0), (0, 0, 255), 5, cv2.LINE_AA)
        cv2.putText(img, "invalid map", (int(img.shape[1] / 2) - 100, int(img.shape[0] / 2)), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (200, 200, 200), 3, cv2.LINE_AA)
        cv2.putText(img, "invalid map", (int(img.shape[1] / 2) - 100, int(img.shape[0] / 2)), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (200, 0, 0), 2, cv2.LINE_AA)
    return img

event_type = ""
event_x = 0
event_y = 0

def run_editor(landscape, state, goal, map_size=12):
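    """Interactive OpenCV map editor.

    Mouse: scrolling the wheel over a cell toggles wall/free space, left click
    sets the start cell, right click sets the goal cell.
    Keys: 's' saves the current task to ./logs/test.json, 'r' restores the
    original landscape, 'q' closes the editor.
    Returns the edited grid as a flat list in the same layout as the input
    landscape.
    """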

    global event_type, event_x, event_y

    # define mouse callback function
    def input_cb(event,x,y,flags,param):
        global event_type, event_x, event_y
        if event == cv2.EVENT_MOUSEWHEEL:
            event_x = x
            event_y = y
            event_type = "flip_space"
        elif event == cv2.EVENT_LBUTTONUP:
            event_x = x
            event_y = y
            event_type = "set_start"
        elif event == cv2.EVENT_RBUTTONUP:
            event_x = x
            event_y = y
            event_type = "set_goal"

    cv2.namedWindow("img", cv2.WINDOW_GUI_NORMAL)
    cv2.setMouseCallback("img", input_cb)

    grid = np.array(landscape).reshape(map_size, map_size).transpose()
    valid = True
    grid_size_display = 20

    while True:

        if event_type == "flip_space":
            grid[event_y//grid_size_display, event_x//grid_size_display] = 1 - grid[event_y//grid_size_display, event_x//grid_size_display]
            event_type = "flip_space_done"
        elif event_type == "set_start":
            state = [event_y//grid_size_display, event_x//grid_size_display]
            event_type = "set_start_done"
        elif event_type == "set_goal":
            goal = [event_y//grid_size_display, event_x//grid_size_display]
            event_type = "set_goal_done"

        num_labels, labels, stats, centroids, num_freespace, landscape_img = check_num_labels(grid, map_size, map_size)
        # the edited map counts as valid only if its non-zero (free) cells form a
        # single connected region (background + one component) and number at least 5
        non_zeros = np.count_nonzero(grid)
        valid = (num_labels == 2 and non_zeros >= 5)
        
        img = render(grid, state, goal, valid)
        cv2.imshow("img", img)
        k = cv2.waitKey(1)
        if k == ord('q'):
            break
        elif k == ord('s'):
            pth = "./logs/test.json"
            grid0 = grid.transpose()
            grid1 = grid0.reshape(map_size*map_size)
            save_current_task([grid1], state[0], state[1], goal[0], goal[1], pth)
            print("task saved to {}".format(pth))
        elif k == ord('r'):
            grid = np.array(landscape).reshape(map_size, map_size).transpose()
    
    grid_ = grid.transpose()
    grid_ = grid_.reshape(map_size*map_size).tolist()
    return grid_


def main():

    """ parse arguments
    """
    rpl_config = ReplayConfig()

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_pth", type=str, default=rpl_config.model_pth)
    parser.add_argument("--map_size", type=int, default=rpl_config.map_size)
    parser.add_argument("--task_pth", type=str, default=rpl_config.task_pth)
    parser.add_argument("--log_pth", type=str, default=rpl_config.log_pth)
    parser.add_argument("--nn_size", type=int, default=rpl_config.nn_size)
    parser.add_argument("--nn_type", type=str, default=rpl_config.nn_type)
    parser.add_argument("--show_kf", type=str, default=rpl_config.show_kf)
    parser.add_argument("--visualization", type=str, default=rpl_config.visualization)
    parser.add_argument("--video_output", type=str, default=rpl_config.video_output)
    parser.add_argument("--life_duration", type=int, default=rpl_config.life_duration)

    args = parser.parse_args()

    rpl_config.model_pth = args.model_pth
    rpl_config.map_size = args.map_size
    rpl_config.task_pth = args.task_pth
    rpl_config.log_pth = args.log_pth
    rpl_config.nn_size = args.nn_size
    rpl_config.nn_type = args.nn_type
    rpl_config.show_kf = args.show_kf
    rpl_config.visualization = args.visualization
    rpl_config.video_output = args.video_output
    rpl_config.life_duration = args.life_duration
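
    # Example invocation (flag values are illustrative; defaults come from ReplayConfig):
    #   python <this_script>.py --nn_type gru --nn_size 128 --map_size 12 \
    #       --visualization True --life_duration 1000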

    k1 = jax.random.PRNGKey(npr.randint(0, 1000000))

    """ load model
    """
    params = load_weights(rpl_config.model_pth)

    # get elements of params
    tree_leaves = jax.tree_util.tree_leaves(params)
    for i in range(len(tree_leaves)):
        print("shape of leaf ", i, ": ", tree_leaves[i].shape)

    bias1 = np.array(tree_leaves[0])
    mat1 = np.array(tree_leaves[1])
    print("mat1.shape: ", mat1.shape)

    mat_obs = np.array(tree_leaves[1])[rpl_config.nn_size:rpl_config.nn_size+10,:]
    mat_intr = np.array(tree_leaves[1])[0:rpl_config.nn_size,:]
    print("mat_obs.shape: ", mat_obs.shape)


    """ create landscape
    """
    random_task = True
    # check if file on rpl_config.task_pth exists
    if os.path.isfile(rpl_config.task_pth):
        random_task = False

    if random_task:
        landscape = generate_maze_pool(num_mazes=1, width=10, height=10)
        landscape = padding_landscapes(landscape, width=12, height=12)
    else:
        landscape, state, goal = load_task(pth = rpl_config.task_pth)
        landscape = [landscape]

    print("landscape :")
    print(landscape)

    """ create agent
    """
    if rpl_config.nn_type == "vanilla":
        model = RNN(hidden_dims = rpl_config.nn_size)
    elif rpl_config.nn_type == "gru":
        model = GRU(hidden_dims = rpl_config.nn_size)

    # check if param fits the agent
    if rpl_config.nn_type == "vanilla":
        assert params["params"]["Dense_0"]["kernel"].shape[0] == rpl_config.nn_size + 10

    """ create grid env
    """
    start_time = time.time()
    GE = GridEnv(landscapes = landscape, width = 12, height = 12, num_envs_per_landscape = 1, reward_free=True)
    GE.reset()
    print("time taken to create envs: ", time.time() - start_time)

    if not random_task:
        # set states of GE
        GE.batched_states = GE.batched_states.at[0, 0].set(state[0])
        GE.batched_states = GE.batched_states.at[0, 1].set(state[1])
        # set goals of GE
        GE.batched_goals = GE.batched_goals.at[0, 0].set(goal[0])
        GE.batched_goals = GE.batched_goals.at[0, 1].set(goal[1])
        GE.init_batched_states, GE.init_batched_goals = jnp.copy(GE.batched_states), jnp.copy(GE.batched_goals)
        GE.batched_goal_reached = batch_compute_goal_reached(GE.batched_states, GE.batched_goals)
        GE.last_batched_goal_reached = jnp.copy(GE.batched_goal_reached)
        GE.concat_obs = get_ideal_obs_vmap(GE.batched_envs, GE.batched_states, GE.batched_goals, GE.last_batched_goal_reached)

    concat_obs = GE.concat_obs

    rnn_state = model.initial_state(GE.num_envs)

    step_count = 0
    render_id = 0

    trajectory = []

    goal_record = []
    HS_trajectory = []
    IPFs = []
    intr_field = []

    reset_ = True

    step_by_step = False
    manual_action = 0

    alter_t = 0

    first_arrival = 0
    
    for t in range(rpl_config.life_duration):

        progress_bar(t, rpl_config.life_duration)

        step_count += 1
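
        # Decompose the next hidden activation into a full (state + observation)
        # update and an intrinsic (state-only) update; IPF stores their
        # difference and intr_field the intrinsic drift of the hidden state.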

        # concatenate rnn_state[0] and concat_obs[0] into a single input vector
        new_vector = jnp.concatenate((rnn_state[0], concat_obs[0]))
        # multiply the concatenated vector by mat1 to get a vector of length nn_size
        result_vector = jnp.dot(new_vector, mat1) + bias1
        result_vector = jnp.tanh(result_vector)
        # the same update using only the recurrent (intrinsic) part of the weights
        intr_vector = jnp.dot(rnn_state[0], mat_intr) + bias1
        intr_vector = jnp.tanh(intr_vector)

        IPF = result_vector - intr_vector
        IPFs.append(IPF)
        intr_field.append(intr_vector - rnn_state[0])

        HS_trajectory.append(np.array(rnn_state[0]))
        goal_record.append(GE.batched_goal_reached[0])

        """ model forward and step the env
        """
        rnn_state, y1 = model_forward(params, rnn_state, concat_obs, model)
        
        if not step_by_step:
            batched_actions = get_action_vmap(y1)
        else:
            batched_actions = jnp.array([manual_action])

        batched_goal_reached, concat_obs = GE.step(batched_actions, reset = reset_)

        """ render the env
        """
        if rpl_config.visualization == "True" or rpl_config.video_output == "True":
            img = GE.render(env_id = render_id)
            if len(trajectory) > 1:
                for i in range(len(trajectory)-1):
                    cv2.line(img, (int(trajectory[i][1]), int(trajectory[i][0])), (int(trajectory[i+1][1]), int(trajectory[i+1][0])), (0,130,0), 2)
            
        trajectory.append([20 * GE.batched_states[render_id][0]+10, 20 * GE.batched_states[render_id][1]+10])

        if batched_goal_reached[render_id]:
            trajectory.clear()
        
        if batched_goal_reached[render_id] and first_arrival == 0:
            first_arrival = 1

        """ scene display
        """
        if rpl_config.visualization == "True":

            cv2.imshow("img", img)

            if step_by_step:
                k = cv2.waitKey(0)
            else:
                k = cv2.waitKey(1)
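            # keys: r = reset episode, p = take a random action, n = new random
            #       maze, q = quit, t = toggle step-by-step mode, w/a/s/d =
            #       manual actions (3/1/2/0), e = open the map editor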
            if k == ord('r'): 
                rnn_state = model.initial_state(GE.num_envs)
                GE.rnd_goal_collection = get_rnd_goal_collection_vmap(GE.env_keys, GE.batched_envs, GE.width, GE.height, GE.num_free_spaces)
                GE.reset()
                if not random_task:
                    # set states of GE
                    GE.batched_states = GE.batched_states.at[0, 0].set(state[0])
                    GE.batched_states = GE.batched_states.at[0, 1].set(state[1])
                    # set goals of GE
                    GE.batched_goals = GE.batched_goals.at[0, 0].set(goal[0])
                    GE.batched_goals = GE.batched_goals.at[0, 1].set(goal[1])
                    GE.init_batched_states, GE.init_batched_goals = jnp.copy(GE.batched_states), jnp.copy(GE.batched_goals)
                    GE.batched_goal_reached = batch_compute_goal_reached(GE.batched_states, GE.batched_goals)
                    GE.last_batched_goal_reached = jnp.copy(GE.batched_goal_reached)
                    GE.concat_obs = get_ideal_obs_vmap(GE.batched_envs, GE.batched_states, GE.batched_goals, GE.last_batched_goal_reached)
                    concat_obs = GE.concat_obs
                    
                trajectory.clear()
            elif k == ord('p'):
                # 1. take a random action
                k2, _ = jax.random.split(k1)
                k1 = k2
                random_action = jax.random.randint(k1, shape=(GE.num_envs, ), minval=0, maxval=4)
                batched_goal_reached, concat_obs = GE.step(random_action)
                print("random action: ", random_action)

            elif k == ord('n'):
                if random_task:
                    rnn_state = model.initial_state(GE.num_envs)
                    landscape = generate_maze_pool(num_mazes=1, width=10, height=10)
                    landscape = padding_landscapes(landscape, width=12, height=12)
                    GE.set_landscapes(landscape)
                    GE.reset()
                    trajectory.clear()
            elif k == ord('q'):
                exit()
            elif k == ord('t'):
                step_by_step = not step_by_step
            elif k == ord('w'):
                manual_action = 3
            elif k == ord('s'):
                manual_action = 2
            elif k == ord('a'):
                manual_action = 1
            elif k == ord('d'):
                manual_action = 0
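            # open the map editor on 'e', or automatically (once) the first
            # time the goal is reached, so the map can be altered mid-episode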
            elif k == ord('e') or first_arrival == 1:
                first_arrival = 2
                alter_t = t
                new_landscape = run_editor(landscape[0], GE.batched_states[0], GE.batched_goals[0])
                GE.set_landscapes_only([new_landscape])
                landscape[0] = new_landscape

    print("shape of goal_record: ", np.array(goal_record).shape)
    print("shape of rnn_state: ", rnn_state.shape)
    print("shape of HS_trajectory: ", np.array(HS_trajectory).shape)
    print("shape of IPFs: ", np.array(IPFs).shape)
    print("shape of intr_field: ", np.array(intr_field).shape)
    print("alter_t: ", alter_t)

    # save HS_trajectory, IPFs, goal_record, and intr_field to files
    np.save("./logs/altered_HS_trajectory.npy", np.array(HS_trajectory))
    np.save("./logs/altered_IPFs.npy", np.array(IPFs))
    np.save("./logs/altered_goal_record.npy", np.array(goal_record))
    np.save("./logs/altered_intr_field.npy", np.array(intr_field))
    np.save("./logs/alter_t.npy", np.array(alter_t))


if __name__ == "__main__":
    main()